#!/usr/bin/env python3
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import gc
import io
import json
import os
import sys
import time
from timeit import timeit

import numpy
import psutil
import rapidjson
import simplejson
from memory_profiler import memory_usage
from tabulate import tabulate

import orjson

# Restrict this process to CPUs 0 and 1 for the duration of the benchmark.
os.sched_setaffinity(os.getpid(), {0, 1})


# The array kind is the first command-line argument. sys.argv always contains
# at least the script name, so an argument is present only when its length
# exceeds 1; a missing argument maps to "" and falls through to the usage
# message below instead of raising IndexError.
kind = sys.argv[1] if len(sys.argv) > 1 else ""

# Build the random array to serialize; size and dtype depend on the kind.
if kind == "int32":
    array = numpy.random.randint(((2 ** 31) - 1), size=(100000, 100), dtype=numpy.int32)
elif kind == "float64":
    array = numpy.random.random(size=(50000, 100))
    assert array.dtype == numpy.float64
elif kind == "bool":
    array = numpy.random.choice((True, False), size=(100000, 200))
elif kind == "int8":
    array = numpy.random.randint(((2 ** 7) - 1), size=(100000, 100), dtype=numpy.int8)
elif kind == "uint8":
    array = numpy.random.randint(((2 ** 8) - 1), size=(100000, 100), dtype=numpy.uint8)
else:
    # Unknown or missing kind: show the accepted values and exit non-zero.
    print("usage: pynumpy (bool|int32|float64|int8|uint8)")
    sys.exit(1)
# Handle on the current process, used later to read its RSS.
proc = psutil.Process()


def default(__obj):
    """Fallback hook passed as ``default=`` to serializers without numpy support.

    Args:
        __obj: The object the serializer could not encode by itself.

    Returns:
        The nested-list form of ``__obj`` when it is a ``numpy.ndarray``.

    Raises:
        TypeError: If ``__obj`` is not a ``numpy.ndarray``. Falling off the
            end would return ``None``, which a serializer would write out as
            a legitimate ``null`` instead of reporting the unsupported type.
    """
    if isinstance(__obj, numpy.ndarray):
        return __obj.tolist()
    raise TypeError(
        f"Object of type {type(__obj).__name__} is not JSON serializable"
    )


# Column titles of the results table.
headers = ("Library", "Latency (ms)", "RSS diff (MiB)", "vs. orjson")

# Libraries to report on, in order. Each one is resolved by name to a
# module-level ``<library>_dumps`` entry, which may be None.
LIBRARIES = ("orjson", "ujson", "rapidjson", "simplejson", "json")

# Number of timed calls per library.
ITERATIONS = 10


def orjson_dumps():
    """Serialize the benchmark array with orjson's numpy option."""
    return orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY)


# No ujson serializer is defined; the benchmark loop treats None as
# "no measurement" for that library.
ujson_dumps = None


def rapidjson_dumps():
    """Serialize the benchmark array with rapidjson and encode it as UTF-8."""
    return rapidjson.dumps(array, default=default).encode("utf-8")


def simplejson_dumps():
    """Serialize the benchmark array with simplejson and encode it as UTF-8."""
    return simplejson.dumps(array, default=default).encode("utf-8")


def json_dumps():
    """Serialize the benchmark array with the stdlib json and encode it as UTF-8."""
    return json.dumps(array, default=default).encode("utf-8")


# Measure the output size inline so the serialized bytes are not kept alive
# when the baseline RSS is sampled below.
output_in_mib = len(orjson_dumps()) / 1024 / 1024

print(f"{output_in_mib:,.1f}MiB {kind} output (orjson)")

# Baseline resident set size in MiB, sampled after a garbage collection.
gc.collect()
baseline_rss_bytes = proc.memory_full_info().rss
mem_before = baseline_rss_bytes / 1024 / 1024


def per_iter_latency(val):
    """Scale a total duration by 1000 and divide it by ITERATIONS.

    ``None`` is passed through unchanged so that an unmeasured library
    stays unmeasured.
    """
    return None if val is None else (val * 1000) / ITERATIONS


def test_correctness(func):
    """Return whether ``func()``'s output round-trips to the benchmark array.

    The bytes returned by ``func`` are parsed with ``orjson.loads`` and
    compared for equality against ``array.tolist()``.
    """
    decoded = orjson.loads(func())
    expected = array.tolist()
    return decoded == expected


# One row per library: (name, latency, RSS diff, ratio to orjson), all strings.
table = []
for lib_name in LIBRARIES:
    # Collect garbage before each library is measured.
    gc.collect()

    print(f"{lib_name}...")
    # Resolve the module-level ``<lib_name>_dumps`` entry by name.
    func = locals()[f"{lib_name}_dumps"]
    if func is not None:
        total_latency = timeit(func, number=ITERATIONS)
        latency = per_iter_latency(total_latency)
        time.sleep(1)
        # NOTE(review): ``latency`` is in milliseconds at this point; confirm
        # the unit memory_usage expects for ``timeout``.
        rss_samples = memory_usage((func,), interval=0.001, timeout=latency * 2)
        mem = max(rss_samples)
        correct = test_correctness(func)
    else:
        # No serializer for this library: nothing measured, row stays blank.
        total_latency = None
        latency = None
        mem = None
        correct = False

    if lib_name == "orjson":
        # orjson is the reference; later ratios divide by its latency.
        compared_to_orjson = 1
        orjson_latency = latency
    else:
        compared_to_orjson = latency / orjson_latency if latency else None

    # An incorrect (or missing) result blanks the numeric cells of the row.
    if not correct:
        latency, mem = None, 0

    mem_diff = mem - mem_before

    latency_cell = f"{latency:,.0f}" if latency else ""
    mem_cell = f"{mem_diff:,.0f}" if mem else ""
    ratio_cell = (
        f"{compared_to_orjson:,.1f}" if (latency and compared_to_orjson) else ""
    )
    table.append((lib_name, latency_cell, mem_cell, ratio_cell))

buf = io.StringIO()
buf.write(tabulate(table, headers, tablefmt="grid"))
buf.write("\n")

# Literal substitutions applied to the rendered grid, strictly in this order;
# each one operates on the result of the previous one.
rendered = buf.getvalue()
for old, new in (
    ("-", ""),
    ("*", "-"),
    ("=", "-"),
    ("+", "|"),
    ("|||||", ""),
    ("\n\n", "\n"),
):
    rendered = rendered.replace(old, new)

print(rendered)
